package deeplearn;

import java.io.File;
import java.io.IOException;

import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * A small GAN trained on MNIST, with generator and discriminator living in one
 * {@link ComputationGraph}.
 *
 * <p>Background article: https://blog.csdn.net/dong_lxkm/article/details/103125742
 *
 * <ol>
 *   <li>DL4J does not offer a Keras-style way to freeze individual layers, so a layer is
 *       "frozen" here by setting its learning rate to 0.</li>
 *   <li>The updater has to be plain SGD. Adaptive updaters (e.g. Adam, RMSProp) take the
 *       gradients of earlier batches into account when computing the current update, so a
 *       learning rate of 0 would not achieve the goal of leaving the parameters untouched.</li>
 *   <li>A {@link StackVertex} merges tensors along the first dimension, i.e. it combines the
 *       real samples and the samples produced by the generator so that both are fed to the
 *       discriminator together.</li>
 *   <li>During training the discriminator is updated several times, then the generator is
 *       updated once.</li>
 *   <li>Training runs for 100,000 iterations.</li>
 *   <li>On the original author's machine the MNIST data was downloaded to
 *       C:\Users\Administrator\.deeplearning4j\data\MNIST.</li>
 * </ol>
 */
public class Gan {

	/** Learning rate given to whichever sub-network (generator or discriminator) is being trained. */
	static double lr = 0.01;
	/** Checkpoint file: loaded at start-up when it exists and overwritten periodically during training. */
	static String model = "F:/face/gan.zip";

	/**
	 * Number of real samples (and of generated samples) per training step. The label arrays are
	 * built once for full batches, so every batch returned by the iterator is assumed to contain
	 * exactly this many rows.
	 */
	private static final int BATCH_SIZE = 30;
	/** Width of the random noise vector fed to the generator. */
	private static final int LATENT_DIM = 10;
	/** Number of values in one flattened 28x28 image. */
	private static final int IMAGE_SIZE = 28 * 28;
	/** Total number of outer training iterations. */
	private static final int ITERATIONS = 100000;
	/** Discriminator updates performed before each single generator update. */
	private static final int DISCRIMINATOR_STEPS = 10;
	/** The model is written to disk every this many outer iterations. */
	private static final int SAVE_EVERY = 10000;
	/** Reporting frequency (in fit iterations) shared by the score and the UI statistics listeners. */
	private static final int LISTENER_FREQUENCY = 100;
	/** Seed passed to the MNIST iterator. */
	private static final int MNIST_SEED = 12345;

	/** Names of the generator layers, in graph order. */
	private static final String[] GENERATOR_LAYERS = { "g1", "g2", "g3" };
	/** Names of the discriminator layers, in graph order. */
	private static final String[] DISCRIMINATOR_LAYERS = { "d1", "d2", "d3", "out" };

	/**
	 * Loads or creates the network, wires up the listeners and runs the alternating
	 * discriminator/generator training loop, saving a checkpoint at regular intervals.
	 *
	 * @param args unused
	 * @throws Exception if the model cannot be loaded or saved, or the data set cannot be read
	 */
	public static void main(String[] args) throws Exception {
		final File modelFile = new File(model);
		final ComputationGraph net = loadOrCreate(modelFile);
		// Kept unconditional, as in the original flow: required for a freshly built graph.
		// NOTE(review): a graph restored from disk is presumably initialized already, which
		// would make this call a no-op in that case - confirm against the DL4J version in use.
		net.init();
		System.out.println(net.summary());

		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		// The StatsListener is what actually feeds the storage attached to the UI; without it
		// the UI server runs but has nothing to display.
		net.setListeners(new ScoreIterationListener(LISTENER_FREQUENCY),
				new StatsListener(statsStorage, LISTENER_FREQUENCY));

		DataSetIterator train = new MnistDataSetIterator(BATCH_SIZE, true, MNIST_SEED);

		// Stack the arrays vertically (row-wise) into one array: the StackVertex puts the real
		// samples ("input2") first and the generated ones ("g3") second, so the discriminator
		// targets are 1 for the first half and 0 for the second half.
		INDArray labelD = Nd4j.vstack(Nd4j.ones(BATCH_SIZE, 1), Nd4j.zeros(BATCH_SIZE, 1));
		// When the generator is trained, every row of the stacked batch is labelled 1.
		INDArray labelG = Nd4j.ones(2 * BATCH_SIZE, 1);

		for (int i = 1; i <= ITERATIONS; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();

			// Feature order matches addInputs("input1", "input2"): noise first, real images second.
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { BATCH_SIZE, LATENT_DIM });
			MultiDataSet dataSetD = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelD });
			for (int m = 0; m < DISCRIMINATOR_STEPS; m++) {
				trainD(net, dataSetD);
			}

			z = Nd4j.rand(new NormalDistribution(), new long[] { BATCH_SIZE, LATENT_DIM });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(net, dataSetG);

			if (i % SAVE_EVERY == 0) {
				saveModel(net, modelFile);
			}
		}
	}

	/**
	 * Discriminator step D(x): sets the generator layers' learning rate to 0 and the
	 * discriminator layers' learning rate to {@link #lr}, then fits the network once.
	 *
	 * @param net     the combined generator/discriminator graph
	 * @param dataSet features {noise, real images} with the discriminator labels
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		setLearningRates(net, 0, lr);
		net.fit(dataSet);
	}

	/**
	 * Generator step g(z): sets the discriminator layers' learning rate to 0 and the
	 * generator layers' learning rate to {@link #lr}, then fits the network once.
	 *
	 * @param net     the combined generator/discriminator graph
	 * @param dataSet features {noise, real images} with the generator labels
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		setLearningRates(net, lr, 0);
		net.fit(dataSet);
	}

	/** Applies one learning rate to all generator layers and another to all discriminator layers. */
	private static void setLearningRates(ComputationGraph net, double generatorLr, double discriminatorLr) {
		for (String layer : GENERATOR_LAYERS) {
			net.setLearningRate(layer, generatorLr);
		}
		for (String layer : DISCRIMINATOR_LAYERS) {
			net.setLearningRate(layer, discriminatorLr);
		}
	}

	/** Restores the network (including updater state) from {@code file} if it exists, else builds a new one. */
	private static ComputationGraph loadOrCreate(File file) throws IOException {
		if (file.exists()) {
			return ComputationGraph.load(file, true);
		}
		return newNetwork();
	}

	/** Builds the (not yet initialized) graph: generator g1-g3, stack vertex, discriminator d1-d3 and output. */
	private static ComputationGraph newNetwork() {
		final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
				.updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER);

		final GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				.addLayer("g1", dense(LATENT_DIM, 128), "input1")
				.addLayer("g2", dense(128, 512), "g1")
				.addLayer("g3", dense(512, IMAGE_SIZE), "g2")
				// Real images first, generated images second; the label layout relies on this order.
				.addVertex("stack", new StackVertex(), "input2", "g3")
				.addLayer("d1", dense(IMAGE_SIZE, 256), "stack")
				.addLayer("d2", dense(256, 128), "d1")
				.addLayer("d3", dense(128, 128), "d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");

		return new ComputationGraph(graphBuilder.build());
	}

	/** Creates a ReLU dense layer with Xavier weight initialization. */
	private static DenseLayer dense(int nIn, int nOut) {
		return new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU)
				.weightInit(WeightInit.XAVIER).build();
	}

	/** Writes the network (including updater state) to {@code file}, creating missing parent directories first. */
	private static void saveModel(ComputationGraph net, File file) throws IOException {
		File parent = file.getAbsoluteFile().getParentFile();
		if (parent != null && !parent.isDirectory() && !parent.mkdirs()) {
			throw new IOException("Cannot create model directory: " + parent);
		}
		net.save(file, true);
	}
}